Unverified commit 81aeea36 authored by Jeff Rasley, committed by GitHub

Elastic training support (#602)

Co-authored-by: Samyam Rajbhandari <samyamr@microsoft.com>
Parent 7435b2f1
......@@ -4,14 +4,12 @@ name: Build
# Controls when the action will run.
on:
# Triggers the workflow on push or pull request events but only for the master branch
push:
branches: [ master ]
paths-ignore:
- 'docs/**'
pull_request:
branches: [ master ]
# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:
paths-ignore:
- 'docs/**'
# A workflow run is made up of one or more jobs that can run sequentially or in parallel
jobs:
......
#!/usr/bin/env python
import argparse
import json
import deepspeed
from deepspeed.elasticity import compute_elastic_config
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, help="DeepSpeed config json")
parser.add_argument('-w', '--world-size', type=int, default=0, help="Intended/current world size")
args = parser.parse_args()
ds_config = json.load(open(args.config, 'r'))
ds_version = deepspeed.__version__
elastic_config = ds_config['elasticity']
print('------------------------------------------')
print("Elasticity config:")
print('------------------------------------------')
print(json.dumps(elastic_config, indent=4, sort_keys=True))
if args.world_size > 0:
final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version, world_size=args.world_size)
print('------------------------------------------')
print(f"Calculated results for world size {args.world_size}:")
print('------------------------------------------')
print(f'final_batch_size .... {final_batch_size}')
print(f'valid_gpus .......... {valid_gpus}')
print(f'micro_batch_size .... {micro_batch_size}')
else:
final_batch_size, valid_gpus = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version)
print('------------------------------------------')
print("Calculated results:")
print('------------------------------------------')
print(f'final_batch_size .... {final_batch_size}')
print(f'valid_gpus .......... {valid_gpus}')
from .elasticity import compute_elastic_config, elasticity_enabled, ensure_immutable_elastic_config
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""
import json
from .constants import *
class ElasticityError(Exception):
"""
Base exception for all elasticity related errors
"""
pass
class ElasticityConfigError(ElasticityError):
"""
Elasticity configuration error
"""
pass
class ElasticityIncompatibleWorldSize(ElasticityError):
"""
Attempting to run a world size that is incompatible with a given elastic config
"""
pass
class ElasticityConfig:
"""
Elastic config object, constructed from a param dictionary that contains only the
elastic config parameters; an example is shown below, and a usage sketch follows this class.
If elasticity is enabled, the user must specify (at least) max_train_batch_size
and micro_batch_sizes.
{
"enabled": true,
"max_train_batch_size": 2000,
"micro_batch_sizes": [2,4,6],
"min_gpus": 1,
"max_gpus" : 10000
"min_time": 20
"ignore_non_elastic_batch_info": false
"version": 0.1
}
"""
def __init__(self, param_dict):
self.enabled = param_dict.get(ENABLED, ENABLED_DEFAULT)
if self.enabled:
if MAX_ACCEPTABLE_BATCH_SIZE in param_dict:
self.max_acceptable_batch_size = param_dict[MAX_ACCEPTABLE_BATCH_SIZE]
else:
raise ElasticityConfigError(
f"Elasticity config missing {MAX_ACCEPTABLE_BATCH_SIZE}")
if MICRO_BATCHES in param_dict:
self.micro_batches = param_dict[MICRO_BATCHES]
else:
raise ElasticityConfigError(f"Elasticity config missing {MICRO_BATCHES}")
else:
self.max_acceptable_batch_size = param_dict.get(
MAX_ACCEPTABLE_BATCH_SIZE,
MAX_ACCEPTABLE_BATCH_SIZE_DEFAULT)
self.micro_batches = param_dict.get(MICRO_BATCHES, MICRO_BATCHES_DEFAULT)
self.min_gpus = param_dict.get(MIN_GPUS, MIN_GPUS_DEFAULT)
self.max_gpus = param_dict.get(MAX_GPUS, MAX_GPUS_DEFAULT)
self.min_time = param_dict.get(MIN_TIME, MIN_TIME_DEFAULT)
self.version = param_dict.get(VERSION, VERSION_DEFAULT)
self.prefer_larger_batch_size = param_dict.get(PREFER_LARGER_BATCH,
PREFER_LARGER_BATCH_DEFAULT)
self.ignore_non_elastic_batch_info = param_dict.get(
IGNORE_NON_ELASTIC_BATCH_INFO,
IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT)
def repr(self):
return self.__dict__
def __repr__(self):
return json.dumps(self.__dict__, sort_keys=True, indent=4)
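For illustration, a minimal sketch of constructing ElasticityConfig from the "elasticity" sub-dictionary of a DeepSpeed config; this assumes the package layout introduced in this PR (deepspeed/elasticity/config.py) and simply prints the resolved fields via __repr__:

    from deepspeed.elasticity.config import ElasticityConfig

    # Only the elastic sub-config is passed in, not the full ds_config.
    elastic_dict = {
        "enabled": True,
        "max_train_batch_size": 2000,
        "micro_batch_sizes": [2, 4, 6],
        "min_gpus": 1,
        "max_gpus": 10000
    }
    cfg = ElasticityConfig(elastic_dict)
    print(cfg)  # dumps the resolved config (including defaults) as sorted, indented JSON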
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""
#########################################
# Elasticity
#########################################
''' The elasticity utility in DeepSpeed can be used to create highly elastic jobs that are
compatible with a large number of GPU counts. For elastic jobs, DeepSpeed provides a batch
size that can support a large number of GPU counts based on the user-specified parameters.
'''
FORMAT = '''
Elasticity should be enabled as:
"elasticity": {
"enabled": true,
"max_train_batch_size": 2000,
"micro_batch_sizes": [2,4,6],
"min_gpus": 1,
"max_gpus" : 10000
"min_time": 20,
"prefer_larger_batch": true,
"ignore_non_elastic_batch_info": false,
"version": 0.1
}
'''
ELASTICITY = 'elasticity'
# Current elasticity version
LATEST_ELASTICITY_VERSION = 0.1
ENABLED = 'enabled'
ENABLED_DEFAULT = False
# Max acceptable train_batch_size
MAX_ACCEPTABLE_BATCH_SIZE = 'max_train_batch_size'
MAX_ACCEPTABLE_BATCH_SIZE_DEFAULT = 2000
# Acceptable micro batch sizes, same as train_micro_batch_size_per_gpu
MICRO_BATCHES = 'micro_batch_sizes'
MICRO_BATCHES_DEFAULT = [2, 4, 6]
# Min/max of GPUs to search over
MIN_GPUS = 'min_gpus'
MIN_GPUS_DEFAULT = 1
MAX_GPUS = 'max_gpus'
MAX_GPUS_DEFAULT = 10000
# Minimum running time (minutes) before the scheduler will scale us
MIN_TIME = "min_time"
MIN_TIME_DEFAULT = "20"
# When finding a suitable batch size, attempt to find one that is closest
# to the max train batch size given.
PREFER_LARGER_BATCH = 'prefer_larger_batch'
PREFER_LARGER_BATCH_DEFAULT = True
# In order to reduce confusion, if elastic mode is enabled we
# require (via assert) that no batch info is set outside of the
# elastic config. You can turn off this assert via this config
# but keep in mind that all batch info defined outside the
# elastic mode *will be ignored*.
IGNORE_NON_ELASTIC_BATCH_INFO = 'ignore_non_elastic_batch_info'
IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT = False
# Version of elastic logic to use
VERSION = "version"
VERSION_DEFAULT = LATEST_ELASTICITY_VERSION
# Minimum deepspeed version to use elasticity
MINIMUM_DEEPSPEED_VERSION = "0.3.8"
# Environment variable storing elastic config from resource scheduler
DEEPSPEED_ELASTICITY_CONFIG = "DEEPSPEED_ELASTICITY_CONFIG"
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""
import os
import re
import json
import numpy as np
from .config import ElasticityConfig, ElasticityConfigError, ElasticityError, \
ElasticityIncompatibleWorldSize
from .constants import ELASTICITY, ENABLED, ENABLED_DEFAULT, LATEST_ELASTICITY_VERSION, \
MINIMUM_DEEPSPEED_VERSION, IGNORE_NON_ELASTIC_BATCH_INFO, \
IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT, DEEPSPEED_ELASTICITY_CONFIG
from ..git_version_info import version as __version__
from ..utils import logger
# The thirty-eight smallest highly composite numbers. This list is enough to
# support batch sizes of up to ~720K.
HCN_LIST = [
1,
2,
4,
6,
12,
24,
36,
48,
60,
120,
180,
240,
360,
720,
840,
1260,
1680,
2520,
5040,
7560,
10080,
15120,
20160,
25200,
27720,
45360,
50400,
55440,
83160,
110880,
166320,
221760,
277200,
332640,
498960,
554400,
665280,
720720
]
def get_candidate_batch_sizes(base_list, max_acceptable_batch_size):
candidate_batch_size = []
#brute force is fine here. We are working with very small lists
for base in base_list:
batch_size = base
for hcn in HCN_LIST:
new_batch_size = base * hcn
if new_batch_size > max_acceptable_batch_size:
break
batch_size = new_batch_size
candidate_batch_size.append(batch_size)
return list(set(candidate_batch_size))
def get_valid_gpus(batch_size, micro_batches, min_valid_gpus, max_valid_gpus):
valid_gpus = []
for micro_batch in micro_batches:
if batch_size % micro_batch == 0:
max_gpus = batch_size // micro_batch
if max_gpus >= min_valid_gpus and max_gpus <= max_valid_gpus:
valid_gpus.append(max_gpus)
for i in range(1, max_gpus // 2 + 1):
if max_gpus % i == 0:
if i >= min_valid_gpus and i <= max_valid_gpus:
valid_gpus.append(i)
valid_gpus = set(valid_gpus)
valid_gpus = sorted(list(valid_gpus))
return valid_gpus
def get_best_candidates(candidate_batch_sizes,
micro_batches,
min_gpus,
max_gpus,
prefer_larger):
max_valid_gpus = 0
valid_gpus = None
final_batch_size = int(min(micro_batches))
for batch_size in candidate_batch_sizes:
current_valid_gpus = get_valid_gpus(batch_size,
micro_batches,
min_gpus,
max_gpus)
if (len(current_valid_gpus) > max_valid_gpus
or (len(current_valid_gpus) == max_valid_gpus and
((prefer_larger and batch_size > final_batch_size) or
(not prefer_larger and batch_size < final_batch_size)))):
max_valid_gpus = len(current_valid_gpus)
valid_gpus = current_valid_gpus
final_batch_size = batch_size
return final_batch_size, valid_gpus
def _get_compatible_gpus_v01(micro_batches,
max_acceptable_batch_size,
min_gpus=None,
max_gpus=None,
prefer_larger=True):
'''We use two heuristics to compute the batch size:
1. We use the Lowest Common Multiple (LCM) of the micro-batches as the base batch size
and scale it by a highly composite number (HCN) such that the result is the largest
batch size that does not exceed max_acceptable_batch_size.
2. We use each of the micro-batches as a base and scale it by an HCN such that the
result is the largest batch size that does not exceed max_acceptable_batch_size.
We then use brute force to count the number of compatible GPU counts for each of the
aforementioned candidates, and return the batch size with the most compatible GPU
counts in the min-max GPU range if provided; otherwise we return the batch size with
the most compatible GPU counts overall. A worked example follows this function.
Returns:
final_batch_size
valid_gpus
'''
if min_gpus is None:
min_gpus = int(1)
if max_gpus is None:
max_gpus = int(max_acceptable_batch_size / min(micro_batches))
assert all(mb <= max_acceptable_batch_size for mb in micro_batches ), \
f"All micro batches must be less than \
or equal to max_acceptable_batch_size: {max_acceptable_batch_size}"
lcm = np.lcm.reduce(micro_batches)
base_list = []
base_list.extend(micro_batches)
base_list.append(lcm)
candidate_batch_sizes = get_candidate_batch_sizes(base_list,
max_acceptable_batch_size)
final_batch_size, valid_gpus = get_best_candidates(
candidate_batch_sizes,
micro_batches,
min_gpus,
max_gpus,
prefer_larger)
return final_batch_size, valid_gpus
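# Worked example for the v0.1 heuristic above (for illustration only): with
# micro_batches=[2, 4, 6] and max_acceptable_batch_size=2000, the LCM base is 12, so the
# base list is [2, 4, 6, 12]. Scaling each base by the largest HCN that keeps it <= 2000
# yields the candidate set {1440, 1680}. 1680 is selected because it is divisible by more
# GPU counts: every divisor of 1680 // 2 = 840 (32 counts) versus every divisor of
# 1440 // 2 = 720 (30 counts).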
def _parse_version(version_str):
    '''Parse a version string and extract the major, minor, and (optionally) patch versions.'''
    matched = re.search(r'^(\d+)\.(\d+)\.(\d+)', version_str)
    if matched:
        return int(matched.group(1)), int(matched.group(2)), int(matched.group(3))
    else:
        matched = re.search(r'^(\d+)\.(\d+)', version_str)
        assert matched is not None, "Unable to parse version number, expecting " \
            f"major.minor[.patch] format but received {version_str}"
        return int(matched.group(1)), int(matched.group(2)), 0
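# e.g. _parse_version("0.3.8") -> (0, 3, 8) and _parse_version("0.4") -> (0, 4, 0)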
def _compatible_ds_version_check(target_deepspeed_version: str):
    err_str = f"Target deepspeed version of {target_deepspeed_version} is not compatible " \
        f"with minimum version {MINIMUM_DEEPSPEED_VERSION} supporting elasticity."
    # Compare (major, minor, patch) as tuples so that, e.g., a 1.0.0 target is not
    # rejected against a 0.3.8 minimum just because its minor version is smaller.
    if _parse_version(target_deepspeed_version) < _parse_version(MINIMUM_DEEPSPEED_VERSION):
        raise ElasticityError(err_str)
    return True
def elasticity_enabled(ds_config: dict):
if ELASTICITY not in ds_config:
return False
return ds_config[ELASTICITY].get(ENABLED, ENABLED_DEFAULT)
def ensure_immutable_elastic_config(runtime_elastic_config_dict: dict):
"""
Ensure the resource scheduler saw the same elastic config we are using at runtime
"""
if DEEPSPEED_ELASTICITY_CONFIG in os.environ:
scheduler_elastic_config_dict = json.loads(
os.environ[DEEPSPEED_ELASTICITY_CONFIG])
scheduler_elastic_config = ElasticityConfig(scheduler_elastic_config_dict)
runtime_elastic_config = ElasticityConfig(runtime_elastic_config_dict)
err_str = "Elastic config '{}={}' seen by resource scheduler does not match config passed to runtime {}={}"
if runtime_elastic_config.max_acceptable_batch_size != scheduler_elastic_config.max_acceptable_batch_size:
raise ElasticityConfigError(
err_str.format('max_acceptable_batch_size',
scheduler_elastic_config.max_acceptable_batch_size,
'max_acceptable_batch_size',
runtime_elastic_config.max_acceptable_batch_size))
if runtime_elastic_config.micro_batches != scheduler_elastic_config.micro_batches:
raise ElasticityConfigError(
err_str.format('micro_batches',
scheduler_elastic_config.micro_batches,
'micro_batches',
runtime_elastic_config.micro_batches))
if runtime_elastic_config.version != scheduler_elastic_config.version:
raise ElasticityConfigError(
err_str.format('version',
scheduler_elastic_config.version,
'version',
runtime_elastic_config.version))
else:
logger.warning("Unable to find DEEPSPEED_ELASTICITY_CONFIG environment variable, cannot " \
"guarantee resource scheduler will scale this job using compatible GPU counts.")
def compute_elastic_config(ds_config: dict, target_deepspeed_version: str, world_size=0):
"""Core deepspeed elasticity API. Given an elastic config (similar to the example below)
DeepSpeed will compute a total train batch size corresponding valid GPU count list that
provides a high level of elasticity. Elasticity in this case means we are safe to scale
the training job up/down across the GPU count list *without* any negative impacts on
training convergence. This is achievable primarily due to DeepSpeed's gradient accumulation
feature which allows us to decompose a global training batch size into:
micro-batch-size * gradient-accumulation-steps * world-size.
"elasticity": {
"enabled": true,
"max_train_batch_size": 2000,
"micro_batch_sizes": [2,4,6],
"min_gpus": 1,
"max_gpus" : 10000
"min_time": 20
"version": 0.1
}
Intended to be called both by scheduling infrastructure and deepspeed runtime.
For the same `ds_config` we should return deterministic results.
Args:
ds_config (dict): DeepSpeed config dictionary/json
target_deepspeed_version (str): When called from scheduling
infrastructure we want to ensure that the target deepspeed version is
compatible with the elasticity version used in the backend.
world_size (int, optional): Intended/current world size, will do some sanity
checks to ensure world size is actually valid with the config.
Raises:
ElasticityConfigError: Missing required elasticity config or elasticity disabled
ElasticityError: If target deepspeed version is not compatible with current version
Returns:
final_batch_size (int): total batch size used for training
valid_gpus (list(int)): list of valid GPU counts with this config
micro_batch_size (int, optional): if world_size is provided will return
specific micro batch size
"""
if not isinstance(ds_config, dict):
raise ValueError("Expected ds_config to be a dictionary but received " \
f"a {type(ds_config)}, containing: {ds_config}")
if ELASTICITY not in ds_config:
raise ElasticityConfigError(f"'{ELASTICITY}' is missing from config json," \
" please add it if running an elastic training job.")
elastic_config_dict = ds_config[ELASTICITY]
if not elastic_config_dict.get(ENABLED, ENABLED_DEFAULT):
raise ElasticityConfigError("Elasticity is disabled, please enable it " \
"('enabled':true) if running an elastic training job.")
elastic_config = ElasticityConfig(elastic_config_dict)
if float(elastic_config.version) > LATEST_ELASTICITY_VERSION:
raise ElasticityConfigError("Attempting to run elasticity version " \
f"{elastic_config.version} but runtime only supports up " \
f"to {LATEST_ELASTICITY_VERSION}")
# Ensure target deepspeed version works with intended elasticity version
if not _compatible_ds_version_check(target_deepspeed_version):
raise ElasticityError("Unable to run elasticity on target deepspeed version of" \
f" {target_deepspeed_version}, currently {__version__}")
if float(elastic_config.version) == 0.1:
final_batch_size, valid_gpus = _get_compatible_gpus_v01(
micro_batches=elastic_config.micro_batches,
max_acceptable_batch_size=elastic_config.max_acceptable_batch_size,
min_gpus=elastic_config.min_gpus,
max_gpus=elastic_config.max_gpus,
prefer_larger=elastic_config.prefer_larger_batch_size)
# ensure batch size is int dtype
final_batch_size = int(final_batch_size)
else:
raise NotImplementedError(
f"Unable to find elastic logic for version: {elastic_config.version}")
if world_size > 0:
if world_size not in valid_gpus:
raise ElasticityIncompatibleWorldSize(f"World size ({world_size}) is not valid " \
f"with the current list of valid GPU counts: {valid_gpus}")
# Pick largest valid micro batch size
micro_batch_size = None
for mbsz in sorted(list(set(elastic_config.micro_batches)), reverse=True):
if final_batch_size // world_size % mbsz == 0:
micro_batch_size = mbsz
break
assert micro_batch_size is not None, "Unable to find divisible micro batch size for" \
    f" world_size={world_size}, final_batch_size={final_batch_size}, and" \
    f" micro_batches={elastic_config.micro_batches}."
return final_batch_size, valid_gpus, micro_batch_size
return final_batch_size, valid_gpus
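For reference, a short usage sketch of compute_elastic_config that mirrors the unit tests included in this PR; for this config the tests expect final_batch_size == 9792, 23 valid GPU counts, and micro_batch_size == 17 when world_size == 64:

    import deepspeed
    from deepspeed.elasticity import compute_elastic_config

    ds_config = {
        "elasticity": {
            "enabled": True,
            "max_train_batch_size": 10000,
            "micro_batch_sizes": [8, 12, 16, 17],
            "min_gpus": 32,
            "max_gpus": 1500,
            "min_time": 20,
            "version": 0.1
        }
    }

    # Passing world_size also returns the per-GPU micro batch size to use.
    final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(
        ds_config=ds_config,
        target_deepspeed_version=deepspeed.__version__,
        world_size=64)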
......@@ -6,13 +6,21 @@ Licensed under the MIT license.
import torch
import json
import copy
from deepspeed.runtime.constants import *
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE
from deepspeed.runtime.config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
from deepspeed.runtime.zero.constants import *
from deepspeed.runtime.activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig
from deepspeed.utils import logger
from .constants import *
from .fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE
from .config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys
from .zero.config import DeepSpeedZeroConfig
from .zero.constants import *
from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig
from ..git_version_info import version as __version__
from ..utils import logger
from ..elasticity import elasticity_enabled, compute_elastic_config, ensure_immutable_elastic_config
from ..elasticity.config import ElasticityConfigError
from ..elasticity.constants import ELASTICITY, IGNORE_NON_ELASTIC_BATCH_INFO, \
IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT
TENSOR_CORE_ALIGN_SIZE = 8
......@@ -504,6 +512,59 @@ class DeepSpeedConfig(object):
self.global_rank = 0
self.world_size = 1
# If elastic mode is enabled, compute the elastic config and update _param_dict accordingly
self.elasticity_enabled = elasticity_enabled(self._param_dict)
if self.elasticity_enabled:
logger.info("DeepSpeed elasticity support enabled")
final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(
ds_config=self._param_dict,
target_deepspeed_version=__version__,
world_size=self.world_size)
elastic_dict = self._param_dict[ELASTICITY]
# Ensure the resource scheduler saw the same elastic config we are using at runtime
ensure_immutable_elastic_config(runtime_elastic_config_dict=elastic_dict)
ignore_non_elastic_batch_info = elastic_dict.get(
IGNORE_NON_ELASTIC_BATCH_INFO,
IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT)
if not ignore_non_elastic_batch_info:
batch_params = [
TRAIN_BATCH_SIZE,
TRAIN_MICRO_BATCH_SIZE_PER_GPU,
GRADIENT_ACCUMULATION_STEPS
]
if any(map(lambda t: t in self._param_dict, batch_params)):
raise ElasticityConfigError("One or more batch related parameters were found in your " \
f"ds_config ({TRAIN_BATCH_SIZE}, {TRAIN_MICRO_BATCH_SIZE_PER_GPU}, and/or " \
f"{GRADIENT_ACCUMULATION_STEPS}). These parameters *will not be used* since " \
"elastic training is enabled, which takes control of these parameters. " \
"If you want to supress this error (the parameters will be silently ignored) " \
f"please set {IGNORE_NON_ELASTIC_BATCH_INFO}':true in your elasticity config.")
# micro_bsz * world_size * gas = total_batch_size
# gas = total_batch_size // (micro_bsz * world_size)
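# e.g. with final_batch_size=9792, micro_batch_size=17, world_size=64 (the values from
# the elasticity unit tests): gas = 9792 // (17 * 64) = 9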
gradient_accu_steps = final_batch_size // (micro_batch_size *
self.world_size)
if TRAIN_BATCH_SIZE in self._param_dict:
logger.warning("[Elasticity] overriding training_batch_size: " \
f"{self._param_dict[TRAIN_BATCH_SIZE]} -> {final_batch_size}")
if TRAIN_MICRO_BATCH_SIZE_PER_GPU in self._param_dict:
logger.warning("[Elasticity] overriding train_micro_batch_size_per_gpu: " \
f"{self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU]} -> {micro_batch_size}")
if GRADIENT_ACCUMULATION_STEPS in self._param_dict:
logger.warning("[Elasticity] overriding gradient_accumulation_steps: "\
f"{self._param_dict[GRADIENT_ACCUMULATION_STEPS]} -> {gradient_accu_steps}")
logger.info(f"[Elasticity] valid GPU counts: {valid_gpus}")
self._param_dict[TRAIN_BATCH_SIZE] = final_batch_size
self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = micro_batch_size
self._param_dict[GRADIENT_ACCUMULATION_STEPS] = gradient_accu_steps
self._initialize_params(self._param_dict)
self._configure_train_batch_size()
self._do_sanity_check()
......
......@@ -13,6 +13,10 @@ def get_scalar_param(param_dict, param_name, param_default_value):
return param_dict.get(param_name, param_default_value)
def get_list_param(param_dict, param_name, param_default_value):
return param_dict.get(param_name, param_default_value)
def dict_raise_error_on_duplicate_keys(ordered_pairs):
"""Reject duplicate keys."""
d = dict((k, v) for k, v in ordered_pairs)
......
......@@ -137,6 +137,10 @@ class DeepSpeedEngine(Module):
self._configure_with_arguments(args, mpu)
self._do_sanity_check()
if mpu is not None:
assert not self.elasticity_enabled(), "Elasticity is not currently supported" \
" with model parallelism."
self._set_distributed_vars()
if self.tensorboard_enabled() and self.global_rank == 0:
......@@ -194,6 +198,22 @@ class DeepSpeedEngine(Module):
self.flatten = util_ops.flatten
self.unflatten = util_ops.unflatten
def get_batch_info(self):
""" Get all training batch related settings.
Returns:
train_batch_size (int): The effective training batch size. This is the amount of data
samples that leads to one step of model update.
train_micro_batch_size_per_gpu (int): Batch size to be processed by one GPU in one
step (without gradient accumulation).
gradient_accumulation_steps (int): Number of training steps to accumulate gradients
before averaging and applying them.
"""
return self.train_batch_size, self.train_micro_batch_size_per_gpu, self.gradient_accumulation_steps
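# Example (hypothetical variable names): after engine, _, _, _ = deepspeed.initialize(...),
# train_bsz, micro_bsz, grad_accum_steps = engine.get_batch_info()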
def elasticity_enabled(self):
return self._config.elasticity_enabled
def pld_enabled(self):
return self._config.pld_enabled
......@@ -1224,10 +1244,13 @@ class DeepSpeedEngine(Module):
if tag is None:
latest_path = os.path.join(load_dir, 'latest')
assert os.path.isfile(latest_path), f"Unable to find latest file at {latest_path}, if trying to load latest " \
"checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint."
with open(latest_path, 'r') as fd:
tag = fd.read().strip()
if os.path.isfile(latest_path):
with open(latest_path, 'r') as fd:
tag = fd.read().strip()
else:
logger.warning(f"Unable to find latest file at {latest_path}, if trying to load latest " \
"checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint.")
return None, None
load_path, client_states = self._load_checkpoint(load_dir,
tag,
......
......@@ -54,6 +54,8 @@ class PipelineEngine(DeepSpeedEngine):
# We schedule the all-reduces, so disable it in super().backward()
self.enable_backward_allreduce = False
assert not self.elasticity_enabled(), "Elasticity is not currently supported" \
" with pipeline parallelism."
# pipeline step for logging
self.log_batch_step_id = -1
......
......@@ -33,7 +33,9 @@ def installed_cuda_version():
def get_default_compute_capatabilities():
compute_caps = DEFAULT_COMPUTE_CAPABILITIES
if installed_cuda_version()[0] >= 11:
import torch.utils.cpp_extension
if torch.utils.cpp_extension.CUDA_HOME is not None and installed_cuda_version()[0] >= 11:
compute_caps += ";8.0;8.6"
return compute_caps
......
......@@ -3,3 +3,4 @@ torchvision>=0.4.0
tqdm
tensorboardX==1.8
ninja
numpy
......@@ -184,7 +184,8 @@ setup(name='deepspeed',
'bin/deepspeed.pt',
'bin/ds',
'bin/ds_ssh',
'bin/ds_report'
'bin/ds_report',
'bin/ds_elastic'
],
classifiers=[
'Programming Language :: Python :: 3.6',
......
......@@ -757,7 +757,7 @@ def test_checkpoint_missing_latest(tmpdir):
model, _, _,_ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
with pytest.raises(AssertionError):
model.load_checkpoint(tmpdir)
# should be no-op, since latest doesn't exist
model.load_checkpoint(tmpdir)
_helper(args=args, model=model, hidden_dim=hidden_dim)
import pytest
import deepspeed
from common import distributed_test
from deepspeed.git_version_info import version as ds_version
from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict
base_ds_config = {
"elasticity": {
"enabled": True,
"max_train_batch_size": 10000,
"micro_batch_sizes": [8,
12,
16,
17],
"min_gpus": 32,
"max_gpus": 1500,
"min_time": 20,
"version": 0.1
}
}
def test_basic_10k():
ds_config = base_ds_config.copy()
final_batch_size, valid_gpus = deepspeed.elasticity.compute_elastic_config(
ds_config=ds_config,
target_deepspeed_version=ds_version)
for gpu_num in valid_gpus:
assert final_batch_size % gpu_num == 0, f"Batch {final_batch_size} is not divisible by GPU count {gpu_num}"
batch_per_gpu = final_batch_size // gpu_num
found_valid_mb = False
for mb in ds_config['elasticity']['micro_batch_sizes']:
if batch_per_gpu % mb == 0:
found_valid_mb = True
break
assert found_valid_mb, "No valid mb found"
assert len(valid_gpus) == 23
assert final_batch_size == 9792
def test_old_version():
ds_config = base_ds_config.copy()
with pytest.raises(deepspeed.elasticity.config.ElasticityError):
final_batch_size, valid_gpus = deepspeed.elasticity.compute_elastic_config(
ds_config=ds_config,
target_deepspeed_version="0.2")
def test_disabled():
ds_config = base_ds_config.copy()
ds_config['elasticity']['enabled'] = False
with pytest.raises(deepspeed.elasticity.config.ElasticityError):
final_batch_size, valid_gpus = deepspeed.elasticity.compute_elastic_config(
ds_config=ds_config,
target_deepspeed_version=ds_version)
def test_valid_world_size():
ds_config = base_ds_config.copy()
final_batch_size, valid_gpus, mbsize = deepspeed.elasticity.compute_elastic_config(
ds_config=ds_config,
target_deepspeed_version=ds_version,
world_size=64)
assert mbsize == 17
def test_invalid_world_size():
ds_config = base_ds_config.copy()
with pytest.raises(deepspeed.elasticity.config.ElasticityIncompatibleWorldSize):
final_batch_size, valid_gpus, mbsize = deepspeed.elasticity.compute_elastic_config(
ds_config=ds_config,
target_deepspeed_version=ds_version,
world_size=128)
def test_future_elastic_version():
ds_config = base_ds_config.copy()
ds_config['elasticity']['version'] = '0.2'
with pytest.raises(deepspeed.elasticity.config.ElasticityError):
deepspeed.elasticity.compute_elastic_config(ds_config=ds_config,
target_deepspeed_version=ds_version)
def test_missing_max_batch():
ds_config = base_ds_config.copy()
del ds_config['elasticity']['max_train_batch_size']
with pytest.raises(deepspeed.elasticity.config.ElasticityError):
deepspeed.elasticity.compute_elastic_config(ds_config=ds_config,
target_deepspeed_version=ds_version)
def test_missing_micro_batch():
ds_config = base_ds_config.copy()
del ds_config['elasticity']['micro_batch_sizes']
with pytest.raises(deepspeed.elasticity.config.ElasticityError):
deepspeed.elasticity.compute_elastic_config(ds_config=ds_config,
target_deepspeed_version=ds_version)
def test_empty_config():
ds_config = {"elasticity": {"enabled": True}}
with pytest.raises(deepspeed.elasticity.config.ElasticityError):
deepspeed.elasticity.compute_elastic_config(ds_config=ds_config,
target_deepspeed_version=ds_version)
def test_proper_mbsz():
ds_config = base_ds_config.copy()
ds_config["elasticity"]["max_train_batch_size"] = 32
ds_config["elasticity"]["micro_batch_sizes"] = [1, 2, 3, 7]
ds_config["elasticity"]["min_gpus"] = 1
final_batch_size, valid_gpus, mbsize = deepspeed.elasticity.compute_elastic_config(
ds_config=ds_config,
target_deepspeed_version=ds_version,
world_size=7)
assert mbsize == 3
def test_non_elastic_batch_params(tmpdir):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Lamb",
"params": {
"lr": 0.00015
}
},
"gradient_clipping": 1.0,
"elasticity": {
"enabled": True,
"max_train_batch_size": 4,
"micro_batch_sizes": [1,
2,
3,
4],
"min_gpus": 1,
"max_gpus": 4,
"min_time": 20,
"version": 0.1
}
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[1, 2])
def _test_elastic(args, model, hidden_dim):
with pytest.raises(deepspeed.elasticity.config.ElasticityError):
model, _, _,_ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
_test_elastic(args=args, model=model, hidden_dim=hidden_dim)
def test_non_elastic_batch_params_w_override(tmpdir):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Lamb",
"params": {
"lr": 0.00015
}
},
"gradient_clipping": 1.0,
"elasticity": {
"enabled": True,
"max_train_batch_size": 4,
"micro_batch_sizes": [1,
2,
3,
4],
"min_gpus": 1,
"max_gpus": 4,
"min_time": 20,
"version": 0.1,
"ignore_non_elastic_batch_info": True
}
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[1, 2])
def _test_elastic(args, model, hidden_dim):
model, _, _,_ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
_test_elastic(args=args, model=model, hidden_dim=hidden_dim)
def test_elastic_config_changed(tmpdir):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": "Lamb",
"params": {
"lr": 0.00015
}
},
"gradient_clipping": 1.0,
"elasticity": {
"enabled": True,
"max_train_batch_size": 4,
"micro_batch_sizes": [1,
2,
3,
4],
"min_gpus": 1,
"max_gpus": 4,
"min_time": 20,
"version": 0.1,
"ignore_non_elastic_batch_info": True
}
}
import json, os
scheduler_elastic_config = config_dict.copy()
scheduler_elastic_config["elasticity"]["max_train_batch_size"] = 27
os.environ['DEEPSPEED_ELASTICITY_CONFIG'] = json.dumps(scheduler_elastic_config)
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[1, 2])
def _test_elastic(args, model, hidden_dim):
with pytest.raises(deepspeed.elasticity.config.ElasticityError):
model, _, _,_ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
_test_elastic(args=args, model=model, hidden_dim=hidden_dim)