'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
import sys
import types

from deepspeed.runtime.engine import DeepSpeedEngine
from deepspeed.runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
from deepspeed.runtime.lr_schedules import add_tuning_arguments
from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.runtime.activation_checkpointing import checkpointing
from deepspeed.ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from deepspeed.utils import logger

try:
    from deepspeed.git_version_info import version, git_hash, git_branch
except ImportError:
    version = "0.0.0+unknown"
    git_hash = None
    git_branch = None

# Export version information
version, __version_tag__ = version.split('+')
__version_major__ = int(version.split('.')[0])
__version_minor__ = int(version.split('.')[1])
__version_patch__ = int(version.split('.')[2])
__version__ = '.'.join(
    map(str,
        [__version_major__,
         __version_minor__,
         __version_patch__]))
__version__ = f"{__version__}+{__version_tag__}"
__git_hash__ = git_hash
__git_branch__ = git_branch

# Provide backwards compatibility with the old deepspeed.pt module structure; this should hopefully no longer be used
pt = types.ModuleType('pt', 'dummy pt module for backwards compatibility')
deepspeed = sys.modules[__name__]
setattr(deepspeed, 'pt', pt)
setattr(deepspeed.pt, 'deepspeed_utils', deepspeed.runtime.utils)
sys.modules['deepspeed.pt'] = deepspeed.pt
sys.modules['deepspeed.pt.deepspeed_utils'] = deepspeed.runtime.utils
setattr(deepspeed.pt, 'deepspeed_config', deepspeed.runtime.config)
sys.modules['deepspeed.pt.deepspeed_config'] = deepspeed.runtime.config
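# With the aliases above, legacy imports such as
#   import deepspeed.pt.deepspeed_utils
#   import deepspeed.pt.deepspeed_config
# now resolve to deepspeed.runtime.utils and deepspeed.runtime.config respectively.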


def initialize(args,
               model,
               optimizer=None,
               model_parameters=None,
               training_data=None,
               lr_scheduler=None,
               mpu=None,
               dist_init_required=None,
               collate_fn=None,
               config_params=None):
    """Initialize the DeepSpeed Engine.

    Arguments:
        args: an object (e.g. an ``argparse.Namespace``) containing ``local_rank`` and
            the ``deepspeed_config`` file location

        model: Required: ``nn.Module`` instance, before applying any wrappers

        optimizer: Optional: a user-defined optimizer; this is typically used instead of
            defining an optimizer in the DeepSpeed JSON config.

        model_parameters: Optional: An iterable of torch.Tensors or dicts.
            Specifies what Tensors should be optimized.

        training_data: Optional: Dataset of type torch.utils.data.Dataset

        lr_scheduler: Optional: Learning Rate Scheduler Object. It should define
            get_lr(), step(), state_dict(), and load_state_dict() methods

        mpu: Optional: A model parallelism unit object that implements
            get_{model,data}_parallel_{rank,group,world_size}()

        dist_init_required: Optional: None will auto-initialize torch.distributed if needed;
            otherwise the user can force it to be initialized or not via a boolean.

        collate_fn: Optional: Merges a list of samples to form a
            mini-batch of Tensor(s).  Used when using batched loading from a
            map-style dataset.

    Returns:
        A tuple of ``engine``, ``optimizer``, ``training_dataloader``, ``lr_scheduler``

        * ``engine``: DeepSpeed runtime engine which wraps the client model for distributed training.

        * ``optimizer``: Wrapped optimizer if a user-defined ``optimizer`` is supplied, or
          if an optimizer is specified in the JSON config; otherwise ``None``.

        * ``training_dataloader``: DeepSpeed dataloader if ``training_data`` was supplied,
          otherwise ``None``.

        * ``lr_scheduler``: Wrapped lr scheduler if a user ``lr_scheduler`` is passed, or
          if an ``lr_scheduler`` is specified in the JSON configuration; otherwise ``None``.
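
    Example:
        A minimal usage sketch; ``MyModel``, ``my_dataset``, and the pre-parsed
        ``args`` are assumed to be defined by the client script::

            model = MyModel()
            engine, optimizer, dataloader, lr_scheduler = deepspeed.initialize(
                args=args,
                model=model,
                model_parameters=model.parameters(),
                training_data=my_dataset)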
    """
    logger.info(
        "DeepSpeed info: version={}, git-hash={}, git-branch={}".format(
            __version__,
            __git_hash__,
            __git_branch__),
    )

    engine = DeepSpeedEngine(args=args,
                             model=model,
                             optimizer=optimizer,
                             model_parameters=model_parameters,
                             training_data=training_data,
                             lr_scheduler=lr_scheduler,
                             mpu=mpu,
                             dist_init_required=dist_init_required,
                             collate_fn=collate_fn,
                             config_params=config_params)

    return_items = [
        engine,
        engine.optimizer,
        engine.training_dataloader,
        engine.lr_scheduler
    ]
    return tuple(return_items)


def _add_core_arguments(parser):
    r"""Helper (internal) function to update an argument parser with an argument group of the core DeepSpeed arguments.
        The core set of DeepSpeed arguments include the following:
        1) --deepspeed: boolean flag to enable DeepSpeed
        2) --deepspeed_config <json file path>: path of a json configuration file to configure DeepSpeed runtime.

        This is a helper function to the public add_config_arguments()

    Arguments:
        parser: argument parser
    Return:
        parser: Updated Parser
    """
    group = parser.add_argument_group('DeepSpeed', 'DeepSpeed configurations')

    group.add_argument(
        '--deepspeed',
        default=False,
        action='store_true',
        help=
        'Enable DeepSpeed (helper flag for user code, no impact on DeepSpeed backend)')

    group.add_argument('--deepspeed_config',
                       default=None,
                       type=str,
                       help='DeepSpeed json configuration file.')

    group.add_argument(
        '--deepscale',
        default=False,
        action='store_true',
        help=
        'Deprecated: enable DeepSpeed (helper flag for user code, no impact on DeepSpeed backend)'
    )

    group.add_argument('--deepscale_config',
                       default=None,
                       type=str,
                       help='Deprecated: DeepSpeed json configuration file.')

    group.add_argument(
        '--deepspeed_mpi',
        default=False,
        action='store_true',
        help=
        "Run via MPI, this will attempt to discover the necessary variables to initialize torch "
        "distributed from the MPI environment")

    return parser


def add_config_arguments(parser):
    r"""Update the argument parser to enabling parsing of DeepSpeed command line arguments.
        The set of DeepSpeed arguments include the following:
        1) --deepspeed: boolean flag to enable DeepSpeed
        2) --deepspeed_config <json file path>: path of a json configuration file to configure DeepSpeed runtime.

    Arguments:
        parser: argument parser
    Return:
        parser: Updated Parser
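
    Example:
        A minimal sketch of a client script extending its own parser; the
        ``--local_rank`` argument shown here is the client's own, added for context::

            import argparse
            parser = argparse.ArgumentParser(description='My training script')
            parser.add_argument('--local_rank', type=int, default=-1)
            parser = deepspeed.add_config_arguments(parser)
            args = parser.parse_args()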
    """
    parser = _add_core_arguments(parser)

    return parser